
Add RNA/protein joint analysis pipeline #1

Merged
hutaobo merged 2 commits into main from codex/add-joint-rna-and-protein-analysis-function
Oct 2, 2025

Conversation

@hutaobo
Owner

@hutaobo hutaobo commented Oct 2, 2025

Summary

  • Add an RNA-driven clustering and protein classification workflow via rna_protein_cluster_analysis
  • Expose the new API from pyXenium.analysis and document its usage in the README

Testing

  • pytest (fails: ModuleNotFoundError: No module named 'pyxenium')

https://chatgpt.com/codex/tasks/task_b_68de48f6cd708326a514242984788e29


Summary by Sourcery

Add a joint RNA/protein analysis workflow that clusters cells based on RNA, trains per-cluster protein classifiers, exposes the API in pyXenium.analysis, and documents usage in the README

New Features:

  • Implement rna_protein_cluster_analysis pipeline to cluster cells on RNA expression and train neural network classifiers for protein heterogeneity within each cluster

Enhancements:

  • Expose rna_protein_cluster_analysis function and ProteinModelResult dataclass in the pyXenium.analysis module

Documentation:

  • Add README section and usage example for the new RNA/protein joint analysis pipeline
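To make the "high vs. low" labelling idea in the summary concrete, here is a hedged, pure-Python sketch of a median split. The real pipeline does this per cluster and per protein on the Xenium protein values; `median_split_labels` is a made-up helper name for illustration only.

```python
from statistics import median

def median_split_labels(values):
    """Label each value 1 ("high") if it is >= the median, else 0 ("low")."""
    threshold = median(values)
    labels = [1 if v >= threshold else 0 for v in values]
    return threshold, labels

# Toy protein intensities for one cluster.
threshold, labels = median_split_labels([1, 4, 5, 9])
# threshold is 4.5; the two values above it become the "high" class.
```

The classifier in the PR is then trained to predict these binary labels from the RNA latent features.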

@sourcery-ai
Contributor

sourcery-ai bot commented Oct 2, 2025


Reviewer's Guide

This PR introduces a self-contained RNA/protein joint analysis workflow, exposing a new API in pyXenium.analysis and updating the README. The pipeline clusters cells on RNA-derived PCs, then trains MLP classifiers per cluster–protein pair to explain protein heterogeneity, returning a summary DataFrame and model objects.

Class diagram for ProteinModelResult and rna_protein_cluster_analysis output

classDiagram
    class ProteinModelResult {
        +str protein
        +str cluster
        +float threshold
        +int n_cells
        +int n_high
        +int n_low
        +float train_accuracy
        +float test_accuracy
        +float test_auc
        +MLPClassifier model
        +StandardScaler scaler
    }
    class rna_protein_cluster_analysis {
        +summary: DataFrame
        +models: Dict[str, Dict[str, ProteinModelResult]]
    }
    rna_protein_cluster_analysis --> "models: ProteinModelResult" ProteinModelResult
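To make the diagram concrete, the sketch below walks the documented `{cluster -> {protein -> ProteinModelResult}}` return shape. The stand-in dataclass is hypothetical and keeps only a few of the fields shown above; the real one also carries the fitted MLPClassifier and StandardScaler.

```python
from dataclasses import dataclass

@dataclass
class ProteinModelResult:  # trimmed stand-in for the dataclass in the diagram
    protein: str
    cluster: str
    test_accuracy: float

# Mirrors the documented nested return shape: {cluster -> {protein -> result}}.
models = {
    "cluster_0": {
        "EPCAM": ProteinModelResult("EPCAM", "cluster_0", 0.84),
        "CD45": ProteinModelResult("CD45", "cluster_0", 0.71),
    },
    "cluster_1": {
        "CD45": ProteinModelResult("CD45", "cluster_1", 0.90),
    },
}

# Example downstream use: best-explained protein per cluster by test accuracy.
best = {
    cluster: max(per_protein.values(), key=lambda r: r.test_accuracy).protein
    for cluster, per_protein in models.items()
}
```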

File-Level Changes

Change Details Files
Add documentation for RNA/protein joint analysis
  • Inserted new TOC entry for joint analysis
  • Added section with usage example in README
README.md
Expose new API in analysis package
  • Imported rna_protein_cluster_analysis and ProteinModelResult
  • Updated all to include new symbols
src/pyXenium/analysis/__init__.py
Implement rna_protein_cluster_analysis pipeline
  • Defined ProteinModelResult dataclass
  • Added helper functions for data extraction, normalization, PCA, and clustering
  • Implemented clustering, per-cluster protein splitting, MLP training, and metric aggregation
src/pyXenium/analysis/rna_protein_joint.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help


@sourcery-ai sourcery-ai bot left a comment


Hey there - I've reviewed your changes and they look great!

Prompt for AI Agents
Please address the comments from this code review:

## Individual Comments

### Comment 1
<location> `src/pyXenium/analysis/rna_protein_joint.py:290-291` </location>
<code_context>
+
+            try:
+                clf.fit(X_train_scaled, y_train)
+            except Exception:
+                continue
+
</code_context>

<issue_to_address>
**suggestion (bug_risk):** Catching all exceptions during model fitting may obscure underlying issues.

Catch only specific exceptions such as ValueError or ConvergenceWarning, so that critical errors are not missed during development.

```suggestion
            from sklearn.exceptions import ConvergenceWarning
            import warnings

            try:
                with warnings.catch_warnings():
                    warnings.filterwarnings("error", category=ConvergenceWarning)
                    clf.fit(X_train_scaled, y_train)
            except (ValueError, ConvergenceWarning):
                continue
```
</issue_to_address>
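For context on the suggestion above: Python only raises a warning as a catchable exception when the warnings filter for it is set to "error". The hedged stdlib sketch below demonstrates the pattern with a made-up `FakeConvergenceWarning` standing in for sklearn's `ConvergenceWarning`.

```python
import warnings

class FakeConvergenceWarning(UserWarning):
    """Hypothetical stand-in for sklearn.exceptions.ConvergenceWarning."""

def fit_that_warns():
    # Simulates an estimator that warns instead of raising on non-convergence.
    warnings.warn("did not converge", FakeConvergenceWarning)

outcome = None
try:
    with warnings.catch_warnings():
        # Promote the warning to an exception so the except clause can see it.
        warnings.filterwarnings("error", category=FakeConvergenceWarning)
        fit_that_warns()
        outcome = "fitted"
except (ValueError, FakeConvergenceWarning):
    outcome = "skipped"  # mirrors the `continue` in the suggested diff
```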

### Comment 2
<location> `src/pyXenium/analysis/rna_protein_joint.py:275` </location>
<code_context>
+                    stratify=labels,
+                )
+            except ValueError:
+                # Not enough samples to stratify.
+                continue
+
</code_context>

<issue_to_address>
**建议 (bug_risk):**`train_test_split` 中遇到 `ValueError` 时静默继续可能会在不通知的情况下跳过集群。

当由于分层样本不足而跳过集群时,添加日志记录以提高可追溯性。

```suggestion
            except ValueError:
                import logging
                logging.warning(
                    "Skipping cluster due to insufficient samples for stratification (cluster size: %d, labels: %s)",
                    len(X_cluster), set(labels)
                )
                continue
```
</issue_to_address>

### Comment 3
<location> `src/pyXenium/analysis/rna_protein_joint.py:94-95` </location>
<code_context>
+def _resolve_protein_frame(adata: AnnData) -> pd.DataFrame:
+    """Return the protein matrix stored in ``adata.obsm['protein']`` as DataFrame."""
+
+    if "protein" not in adata.obsm:
+        raise KeyError("AnnData is missing 'protein' modality in adata.obsm['protein'].")
+
</code_context>

<issue_to_address>
**suggestion:** A KeyError is raised if 'protein' is missing, but no fallback or guidance is provided.

Consider updating the error message to include instructions for adding or computing the 'protein' modality, making it easier for users to resolve the issue.

```suggestion
    if "protein" not in adata.obsm:
        raise KeyError(
            "AnnData is missing 'protein' modality in adata.obsm['protein'].\n"
            "To resolve this, ensure that the protein data is loaded or computed and assigned to adata.obsm['protein'].\n"
            "For example, use pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein or another appropriate method to add the protein modality."
        )
```
</issue_to_address>

### Comment 4
<location> `src/pyXenium/analysis/rna_protein_joint.py:231` </location>
<code_context>
+        cluster_protein = protein_df.iloc[idx]
+        models.setdefault(cluster, {})
+
+        for protein in cluster_protein.columns:
+            values_all = cluster_protein[protein].to_numpy(dtype=np.float32)
+            finite_mask = np.isfinite(values_all)
</code_context>

<issue_to_address>
**issue (complexity):** Consider extracting the protein data preparation and classifier fitting logic into helper functions to flatten the main loop and reduce cyclomatic complexity.

Here are two small helpers you can extract from the inner loops to flatten that nesting, reduce cyclomatic complexity, and keep all functionality:

```python
def _prepare_protein_data(
    values: np.ndarray,
    X_cluster: np.ndarray,
    method: str,
    quantile: float,
    min_cells: int,
) -> Tuple[Optional[float], np.ndarray, np.ndarray]:
    """Return threshold, filtered X and labels or (None,_,_) if not enough cells."""
    finite = np.isfinite(values)
    if finite.sum() < min_cells * 2:
        return None, None, None

    vals = values[finite]
    Xc  = X_cluster[finite]

    if method == "median":
        thr    = float(np.median(vals))
        labels = (vals >= thr).astype(int)
    elif method == "quantile":
        q = float(quantile)
        if not 0.5 < q < 1.0:
            raise ValueError("protein_quantile must be between 0.5 and 1.0")
        high_thr = np.quantile(vals, q)
        low_mask  = vals <= np.quantile(vals, 1 - q)
        high_mask = vals >= high_thr
        sel       = low_mask | high_mask
        if sel.sum() < min_cells * 2:
            return None, None, None
        vals   = vals[sel]
        Xc     = Xc[sel]
        labels = high_mask[sel].astype(int)
        thr    = float(high_thr)
    else:
        raise ValueError("protein_split_method must be 'median' or 'quantile'")

    if labels.sum() < min_cells or (len(labels) - labels.sum()) < min_cells:
        return None, None, None

    return thr, Xc, labels
```

```python
def _fit_protein_classifier(
    X: np.ndarray,
    y: np.ndarray,
    test_size: float,
    random_state: Optional[int],
    hidden_layer_sizes: Tuple[int, ...],
    max_iter: int,
    early_stopping: bool,
) -> Optional[ProteinModelResult]:
    """Train/test split + scaler + MLPClassifier, return result or None on error."""
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=random_state
        )
    except ValueError:
        return None

    scaler = StandardScaler()
    Xtr = scaler.fit_transform(X_train)
    Xte = scaler.transform(X_test)

    clf = MLPClassifier(
        hidden_layer_sizes=hidden_layer_sizes,
        random_state=random_state,
        max_iter=max_iter,
        early_stopping=early_stopping,
    )
    try:
        clf.fit(Xtr, y_train)
    except Exception:
        return None

    train_acc = accuracy_score(y_train, clf.predict(Xtr))
    y_pred    = clf.predict(Xte)
    test_acc  = accuracy_score(y_test, y_pred)
    if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2:
        try:
            auc = roc_auc_score(y_test, clf.predict_proba(Xte)[:, 1])
        except ValueError:
            auc = float("nan")
    else:
        auc = float("nan")

    return ProteinModelResult(
        protein="",
        cluster="",
        threshold=0.0,
        n_cells=len(y),
        n_high=int(y.sum()),
        n_low=int(len(y) - y.sum()),
        train_accuracy=float(train_acc),
        test_accuracy=float(test_acc),
        test_auc=float(auc),
        model=clf,
        scaler=scaler,
    )
```

Then in your main loop replace the big block with:

```python
for protein in cluster_protein.columns:
    vals = cluster_protein[protein].to_numpy(dtype=np.float32)
    thr, X_sel, labels = _prepare_protein_data(
        vals, cluster_pcs, protein_split_method, protein_quantile, min_cells_per_group
    )
    if thr is None:
        continue

    result = _fit_protein_classifier(
        X_sel, labels, test_size, random_state,
        hidden_layer_sizes, max_iter, early_stopping,
    )
    if result is None:
        continue

    result.protein = protein
    result.cluster = cluster
    result.threshold = thr

    models[cluster][protein] = result
    results.append({
        "cluster": cluster,
        "protein": protein,
        "threshold": thr,
        "n_cells": result.n_cells,
        "n_high": result.n_high,
        "n_low": result.n_low,
        "train_accuracy": result.train_accuracy,
        "test_accuracy": result.test_accuracy,
        "test_auc": result.test_auc,
    })
```

This pulls out ~60 lines of nested logic into two focused helpers and flattens your main loop.
</issue_to_address>

### Comment 5
<location> `src/pyXenium/analysis/rna_protein_joint.py:50-59` </location>
<code_context>
def _get_rna_matrix(adata: AnnData):
    """Return the raw RNA matrix from ``adata`` as CSR sparse matrix."""

    if "rna" in adata.layers:
        X = adata.layers["rna"]
    else:
        X = adata.X

    if sparse.issparse(X):
        return X.tocsr()

    # Dense array – convert to CSR to keep operations memory friendly.
    return sparse.csr_matrix(np.asarray(X))

</code_context>

<issue_to_address>
**suggestion (code quality):** We found these issues:

- Replace if statement with if expression [×2] ([`assign-if-exp`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/assign-if-exp/))
- Lift code into else after jump in control flow ([`reintroduce-else`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/reintroduce-else/))

```suggestion
    X = adata.layers["rna"] if "rna" in adata.layers else adata.X
    return X.tocsr() if sparse.issparse(X) else sparse.csr_matrix(np.asarray(X))
```
</issue_to_address>

### Comment 6
<location> `src/pyXenium/analysis/rna_protein_joint.py:244` </location>
<code_context>
def rna_protein_cluster_analysis(
    adata: AnnData,
    *,
    n_clusters: int = 12,
    n_pcs: int = 30,
    cluster_key: str = "rna_cluster",
    random_state: Optional[int] = 0,
    target_sum: float = 1e4,
    min_cells_per_cluster: int = 50,
    min_cells_per_group: int = 20,
    protein_split_method: str = "median",
    protein_quantile: float = 0.75,
    test_size: float = 0.2,
    hidden_layer_sizes: Tuple[int, ...] = (64, 32),
    max_iter: int = 200,
    early_stopping: bool = True,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, ProteinModelResult]]]:
    """Joint RNA/protein analysis for Xenium AnnData objects.

    The pipeline performs three consecutive steps:

    1. **RNA preprocessing** – library-size normalisation (counts per ``target_sum``)
       followed by ``log1p``.  A :class:`~sklearn.decomposition.TruncatedSVD`
       is fitted to obtain ``n_pcs`` latent dimensions.
    2. **Clustering** – :class:`~sklearn.cluster.KMeans` is applied on the latent
       representation to create ``n_clusters`` RNA-driven cell groups.  Cluster
       assignments are stored in ``adata.obs[cluster_key]`` and the latent space
       in ``adata.obsm['X_rna_pca']``.
    3. **Protein explanation** – for every cluster and every protein marker, the
       cells are divided into "high" vs. "low" groups (median split by default).
       A small neural network (:class:`~sklearn.neural_network.MLPClassifier`)
       is trained to predict the binary labels from the RNA latent features.  The
       training/test accuracies and optional ROC-AUC are reported.

    Parameters
    ----------
    adata:
        AnnData object returned by
        :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`.
        Requires ``adata.layers['rna']`` (or ``adata.X``) and
        ``adata.obsm['protein']``.
    n_clusters:
        Number of RNA clusters to compute with KMeans.
    n_pcs:
        Number of latent components extracted with TruncatedSVD.  The value is
        automatically capped at ``n_genes - 1``.
    cluster_key:
        Column name added to ``adata.obs`` that stores cluster labels.
    random_state:
        Seed for the SVD, KMeans and neural networks.  Use ``None`` for random
        initialisation.
    target_sum:
        Target library size after normalisation (Counts Per ``target_sum``).
    min_cells_per_cluster:
        Clusters with fewer cells are skipped entirely.
    min_cells_per_group:
        Minimum number of cells required in both "high" and "low" protein
        groups to train a neural network.
    protein_split_method:
        Either ``"median"`` (default) for a median split or ``"quantile"`` to
        keep only the top ``protein_quantile`` and bottom ``1 - protein_quantile``
        fractions of cells (discarding the middle portion).
    protein_quantile:
        Quantile used when ``protein_split_method='quantile'``.
    test_size:
        Fraction of the cluster reserved for the test split when training the
        neural network.
    hidden_layer_sizes:
        Hidden-layer configuration passed to :class:`MLPClassifier`.
    max_iter:
        Maximum number of training iterations for the neural network.
    early_stopping:
        Whether to use early stopping in :class:`MLPClassifier`.

    Returns
    -------
    summary:
        :class:`pandas.DataFrame` summarising the trained models.  Columns are
        ``['cluster', 'protein', 'threshold', 'n_cells', 'n_high', 'n_low',
        'train_accuracy', 'test_accuracy', 'test_auc']``.
    models:
        Nested dictionary ``{cluster -> {protein -> ProteinModelResult}}``
        containing the fitted neural networks and scalers for downstream use.

    Examples
    --------
    >>> from pyXenium.analysis import rna_protein_cluster_analysis
    >>> summary, models = rna_protein_cluster_analysis(adata, n_clusters=8)
    >>> summary.head()
          cluster          protein  threshold  n_cells  ...  test_accuracy  test_auc
    0    cluster_0      EPCAM (µm)   0.563100      512  ...           0.84      0.91
    1    cluster_0  Podocin (µm^2)   0.118775      512  ...           0.79      0.87
    """

    if adata.n_obs == 0:
        raise ValueError("AnnData contains no cells (n_obs == 0).")

    protein_df = _resolve_protein_frame(adata)
    if protein_df.shape[1] == 0:
        raise ValueError("AnnData.obsm['protein'] is empty – nothing to analyse.")

    rna_csr = _get_rna_matrix(adata)
    if rna_csr.shape[1] < 2:
        raise ValueError("RNA modality must have at least two genes for clustering.")

    log_norm = _normalize_log1p(rna_csr, target_sum=target_sum)
    pcs = _fit_pcs(log_norm, n_components=n_pcs, random_state=random_state)
    adata.obsm["X_rna_pca"] = pcs

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    cluster_labels = kmeans.fit_predict(pcs)
    cluster_names = np.array([f"cluster_{i}" for i in cluster_labels])
    adata.obs[cluster_key] = cluster_names

    results: List[Dict[str, float]] = []
    models: Dict[str, Dict[str, ProteinModelResult]] = {}

    unique_clusters = pd.Index(np.unique(cluster_names))
    for cluster in unique_clusters:
        idx = np.where(cluster_names == cluster)[0]
        if idx.size < min_cells_per_cluster:
            continue

        cluster_pcs = pcs[idx]
        cluster_protein = protein_df.iloc[idx]
        models.setdefault(cluster, {})

        for protein in cluster_protein.columns:
            values_all = cluster_protein[protein].to_numpy(dtype=np.float32)
            finite_mask = np.isfinite(values_all)
            if finite_mask.sum() < min_cells_per_group * 2:
                continue

            values = values_all[finite_mask]
            X_cluster = cluster_pcs[finite_mask]

            if protein_split_method == "median":
                threshold = float(np.median(values))
                labels = (values >= threshold).astype(int)
            elif protein_split_method == "quantile":
                q = float(protein_quantile)
                if not 0.5 < q < 1.0:
                    raise ValueError("protein_quantile must be between 0.5 and 1.0 (exclusive).")
                high_thr = np.quantile(values, q)
                low_mask = values <= np.quantile(values, 1.0 - q)
                high_mask = values >= high_thr
                selected_mask = high_mask | low_mask
                if selected_mask.sum() < min_cells_per_group * 2:
                    continue
                values = values[selected_mask]
                X_cluster = X_cluster[selected_mask]
                labels = high_mask[selected_mask].astype(int)
                threshold = float(high_thr)
            else:
                raise ValueError("protein_split_method must be 'median' or 'quantile'.")

            n_selected = labels.size
            n_high = int(labels.sum())
            n_low = int(n_selected - n_high)

            if n_high < min_cells_per_group or n_low < min_cells_per_group:
                continue

            try:
                X_train, X_test, y_train, y_test = train_test_split(
                    X_cluster,
                    labels,
                    test_size=test_size,
                    random_state=random_state,
                    stratify=labels,
                )
            except ValueError:
                # Not enough samples to stratify.
                continue

            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            clf = MLPClassifier(
                hidden_layer_sizes=hidden_layer_sizes,
                random_state=random_state,
                max_iter=max_iter,
                early_stopping=early_stopping,
            )

            try:
                clf.fit(X_train_scaled, y_train)
            except Exception:
                continue

            train_acc = float(accuracy_score(y_train, clf.predict(X_train_scaled)))
            test_pred = clf.predict(X_test_scaled)
            test_acc = float(accuracy_score(y_test, test_pred))

            if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2:
                probs = clf.predict_proba(X_test_scaled)[:, 1]
                try:
                    test_auc = float(roc_auc_score(y_test, probs))
                except ValueError:
                    test_auc = float("nan")
            else:
                test_auc = float("nan")

            result = ProteinModelResult(
                protein=protein,
                cluster=cluster,
                threshold=threshold,
                n_cells=n_selected,
                n_high=n_high,
                n_low=n_low,
                train_accuracy=train_acc,
                test_accuracy=test_acc,
                test_auc=test_auc,
                model=clf,
                scaler=scaler,
            )

            models[cluster][protein] = result
            results.append(
                {
                    "cluster": cluster,
                    "protein": protein,
                    "threshold": threshold,
                    "n_cells": n_selected,
                    "n_high": n_high,
                    "n_low": n_low,
                    "train_accuracy": train_acc,
                    "test_accuracy": test_acc,
                    "test_auc": test_auc,
                }
            )

    summary = pd.DataFrame(results, columns=[
        "cluster",
        "protein",
        "threshold",
        "n_cells",
        "n_high",
        "n_low",
        "train_accuracy",
        "test_accuracy",
        "test_auc",
    ])

    return summary, models

</code_context>

<issue_to_address>
**suggestion (code quality):** We found these issues:

- Simplify conditional into switch-like form ([`switch`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/switch/))
- Remove unnecessary casts to int, str, float or bool ([`remove-unnecessary-cast`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/remove-unnecessary-cast/))
- Low code quality found in `rna_protein_cluster_analysis` - 6% ([`low-code-quality`](https://docs.sourcery.ai/Reference/Default-Rules/comments/low-code-quality/))

```suggestion
                q = protein_quantile
```

<br/><details><summary>Explanation</summary>

The quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.

How can you solve this?

It might be worth refactoring this function to make it shorter and more readable.

- Reduce the function length by extracting pieces of functionality into their own functions. This is the most important thing you can do - ideally a function should be less than 10 lines.
- Reduce nesting, perhaps by introducing guard clauses to return early.
- Ensure that variables are tightly scoped, so that code using related concepts sits together within the function rather than being scattered.</details>
</issue_to_address>
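As a minimal illustration of the guard-clause advice in the comment above (toy functions, not taken from the PR): both versions below implement the same minimum-cell check used before training, but the guarded one avoids nesting and returns early.

```python
def label_counts_ok_nested(n_high, n_low, min_cells):
    # Nested style: each condition adds a level of indentation.
    if n_high >= min_cells:
        if n_low >= min_cells:
            return True
    return False

def label_counts_ok_guarded(n_high, n_low, min_cells):
    # Guard-clause style: bail out early, keep the happy path flat.
    if n_high < min_cells:
        return False
    if n_low < min_cells:
        return False
    return True
```

Both functions are equivalent; the guarded form simply scales better as more preconditions are added.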

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
</issue_to_address>

### Comment 5
<location> `src/pyXenium/analysis/rna_protein_joint.py:50-59` </location>
<code_context>
def _get_rna_matrix(adata: AnnData):
    """Return the raw RNA matrix from ``adata`` as CSR sparse matrix."""

    if "rna" in adata.layers:
        X = adata.layers["rna"]
    else:
        X = adata.X

    if sparse.issparse(X):
        return X.tocsr()

    # Dense array – convert to CSR to keep operations memory friendly.
    return sparse.csr_matrix(np.asarray(X))

</code_context>

<issue_to_address>
**suggestion (code-quality):** We've found these issues:

- Replace if statement with if expression [×2] ([`assign-if-exp`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/assign-if-exp/))
- Lift code into else after jump in control flow ([`reintroduce-else`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/reintroduce-else/))

```suggestion
    X = adata.layers["rna"] if "rna" in adata.layers else adata.X
    return X.tocsr() if sparse.issparse(X) else sparse.csr_matrix(np.asarray(X))
```
</issue_to_address>

### Comment 6
<location> `src/pyXenium/analysis/rna_protein_joint.py:244` </location>
<code_context>
def rna_protein_cluster_analysis(
    adata: AnnData,
    *,
    n_clusters: int = 12,
    n_pcs: int = 30,
    cluster_key: str = "rna_cluster",
    random_state: Optional[int] = 0,
    target_sum: float = 1e4,
    min_cells_per_cluster: int = 50,
    min_cells_per_group: int = 20,
    protein_split_method: str = "median",
    protein_quantile: float = 0.75,
    test_size: float = 0.2,
    hidden_layer_sizes: Tuple[int, ...] = (64, 32),
    max_iter: int = 200,
    early_stopping: bool = True,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, ProteinModelResult]]]:
    """Joint RNA/protein analysis for Xenium AnnData objects.

    The pipeline performs three consecutive steps:

    1. **RNA preprocessing** – library-size normalisation (counts per ``target_sum``)
       followed by ``log1p``.  A :class:`~sklearn.decomposition.TruncatedSVD`
       is fitted to obtain ``n_pcs`` latent dimensions.
    2. **Clustering** – :class:`~sklearn.cluster.KMeans` is applied on the latent
       representation to create ``n_clusters`` RNA-driven cell groups.  Cluster
       assignments are stored in ``adata.obs[cluster_key]`` and the latent space
       in ``adata.obsm['X_rna_pca']``.
    3. **Protein explanation** – for every cluster and every protein marker, the
       cells are divided into "high" vs. "low" groups (median split by default).
       A small neural network (:class:`~sklearn.neural_network.MLPClassifier`)
       is trained to predict the binary labels from the RNA latent features.  The
       training/test accuracies and optional ROC-AUC are reported.

    Parameters
    ----------
    adata:
        AnnData object returned by
        :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`.
        Requires ``adata.layers['rna']`` (or ``adata.X``) and
        ``adata.obsm['protein']``.
    n_clusters:
        Number of RNA clusters to compute with KMeans.
    n_pcs:
        Number of latent components extracted with TruncatedSVD.  The value is
        automatically capped at ``n_genes - 1``.
    cluster_key:
        Column name added to ``adata.obs`` that stores cluster labels.
    random_state:
        Seed for the SVD, KMeans and neural networks.  Use ``None`` for random
        initialisation.
    target_sum:
        Target library size after normalisation (Counts Per ``target_sum``).
    min_cells_per_cluster:
        Clusters with fewer cells are skipped entirely.
    min_cells_per_group:
        Minimum number of cells required in both "high" and "low" protein
        groups to train a neural network.
    protein_split_method:
        Either ``"median"`` (default) for a median split or ``"quantile"`` to
        keep only the top ``protein_quantile`` and bottom ``1 - protein_quantile``
        fractions of cells (discarding the middle portion).
    protein_quantile:
        Quantile used when ``protein_split_method='quantile'``.
    test_size:
        Fraction of the cluster reserved for the test split when training the
        neural network.
    hidden_layer_sizes:
        Hidden-layer configuration passed to :class:`MLPClassifier`.
    max_iter:
        Maximum number of training iterations for the neural network.
    early_stopping:
        Whether to use early stopping in :class:`MLPClassifier`.

    Returns
    -------
    summary:
        :class:`pandas.DataFrame` summarising the trained models.  Columns are
        ``['cluster', 'protein', 'threshold', 'n_cells', 'n_high', 'n_low',
        'train_accuracy', 'test_accuracy', 'test_auc']``.
    models:
        Nested dictionary ``{cluster -> {protein -> ProteinModelResult}}``
        containing the fitted neural networks and scalers for downstream use.

    Examples
    --------
    >>> from pyXenium.analysis import rna_protein_cluster_analysis
    >>> summary, models = rna_protein_cluster_analysis(adata, n_clusters=8)
    >>> summary.head()
          cluster          protein  threshold  n_cells  ...  test_accuracy  test_auc
    0    cluster_0      EPCAM (µm)   0.563100      512  ...           0.84      0.91
    1    cluster_0  Podocin (µm^2)   0.118775      512  ...           0.79      0.87
    """

    if adata.n_obs == 0:
        raise ValueError("AnnData contains no cells (n_obs == 0).")

    protein_df = _resolve_protein_frame(adata)
    if protein_df.shape[1] == 0:
        raise ValueError("AnnData.obsm['protein'] is empty – nothing to analyse.")

    rna_csr = _get_rna_matrix(adata)
    if rna_csr.shape[1] < 2:
        raise ValueError("RNA modality must have at least two genes for clustering.")

    log_norm = _normalize_log1p(rna_csr, target_sum=target_sum)
    pcs = _fit_pcs(log_norm, n_components=n_pcs, random_state=random_state)
    adata.obsm["X_rna_pca"] = pcs

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    cluster_labels = kmeans.fit_predict(pcs)
    cluster_names = np.array([f"cluster_{i}" for i in cluster_labels])
    adata.obs[cluster_key] = cluster_names

    results: List[Dict[str, float]] = []
    models: Dict[str, Dict[str, ProteinModelResult]] = {}

    unique_clusters = pd.Index(np.unique(cluster_names))
    for cluster in unique_clusters:
        idx = np.where(cluster_names == cluster)[0]
        if idx.size < min_cells_per_cluster:
            continue

        cluster_pcs = pcs[idx]
        cluster_protein = protein_df.iloc[idx]
        models.setdefault(cluster, {})

        for protein in cluster_protein.columns:
            values_all = cluster_protein[protein].to_numpy(dtype=np.float32)
            finite_mask = np.isfinite(values_all)
            if finite_mask.sum() < min_cells_per_group * 2:
                continue

            values = values_all[finite_mask]
            X_cluster = cluster_pcs[finite_mask]

            if protein_split_method == "median":
                threshold = float(np.median(values))
                labels = (values >= threshold).astype(int)
            elif protein_split_method == "quantile":
                q = float(protein_quantile)
                if not 0.5 < q < 1.0:
                    raise ValueError("protein_quantile must be between 0.5 and 1.0 (exclusive).")
                high_thr = np.quantile(values, q)
                low_mask = values <= np.quantile(values, 1.0 - q)
                high_mask = values >= high_thr
                selected_mask = high_mask | low_mask
                if selected_mask.sum() < min_cells_per_group * 2:
                    continue
                values = values[selected_mask]
                X_cluster = X_cluster[selected_mask]
                labels = high_mask[selected_mask].astype(int)
                threshold = float(high_thr)
            else:
                raise ValueError("protein_split_method must be 'median' or 'quantile'.")

            n_selected = labels.size
            n_high = int(labels.sum())
            n_low = int(n_selected - n_high)

            if n_high < min_cells_per_group or n_low < min_cells_per_group:
                continue

            try:
                X_train, X_test, y_train, y_test = train_test_split(
                    X_cluster,
                    labels,
                    test_size=test_size,
                    random_state=random_state,
                    stratify=labels,
                )
            except ValueError:
                # Not enough samples to stratify.
                continue

            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            clf = MLPClassifier(
                hidden_layer_sizes=hidden_layer_sizes,
                random_state=random_state,
                max_iter=max_iter,
                early_stopping=early_stopping,
            )

            try:
                clf.fit(X_train_scaled, y_train)
            except Exception:
                continue

            train_acc = float(accuracy_score(y_train, clf.predict(X_train_scaled)))
            test_pred = clf.predict(X_test_scaled)
            test_acc = float(accuracy_score(y_test, test_pred))

            if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2:
                probs = clf.predict_proba(X_test_scaled)[:, 1]
                try:
                    test_auc = float(roc_auc_score(y_test, probs))
                except ValueError:
                    test_auc = float("nan")
            else:
                test_auc = float("nan")

            result = ProteinModelResult(
                protein=protein,
                cluster=cluster,
                threshold=threshold,
                n_cells=n_selected,
                n_high=n_high,
                n_low=n_low,
                train_accuracy=train_acc,
                test_accuracy=test_acc,
                test_auc=test_auc,
                model=clf,
                scaler=scaler,
            )

            models[cluster][protein] = result
            results.append(
                {
                    "cluster": cluster,
                    "protein": protein,
                    "threshold": threshold,
                    "n_cells": n_selected,
                    "n_high": n_high,
                    "n_low": n_low,
                    "train_accuracy": train_acc,
                    "test_accuracy": test_acc,
                    "test_auc": test_auc,
                }
            )

    summary = pd.DataFrame(results, columns=[
        "cluster",
        "protein",
        "threshold",
        "n_cells",
        "n_high",
        "n_low",
        "train_accuracy",
        "test_accuracy",
        "test_auc",
    ])

    return summary, models

</code_context>

<issue_to_address>
**suggestion (code-quality):** We've found these issues:

- Simplify conditional into switch-like form ([`switch`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/switch/))
- Remove unnecessary casts to int, str, float or bool ([`remove-unnecessary-cast`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/remove-unnecessary-cast/))
- Low code quality found in rna\_protein\_cluster\_analysis - 6% ([`low-code-quality`](https://docs.sourcery.ai/Reference/Default-Rules/comments/low-code-quality/))

```suggestion
                q = protein_quantile
```

<br/><details><summary>Explanation</summary>

The quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.

How can you solve this?

It might be worth refactoring this function to make it shorter and more readable.

- Reduce the function length by extracting pieces of functionality out into
  their own functions. This is the most important thing you can do - ideally a
  function should be less than 10 lines.
- Reduce nesting, perhaps by introducing guard clauses to return early.
- Ensure that variables are tightly scoped, so that code using related concepts
  sits together within the function rather than being scattered.</details>
</issue_to_address>
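Step 1 of the pipeline described in the docstring above (library-size normalisation followed by `log1p`) can be sketched in isolation. This is a simplified, dense-array stand-in for the `_normalize_log1p` helper referenced in the code — the real helper operates on a CSR sparse matrix, so the name and signature here are illustrative only:

```python
import numpy as np

def normalize_log1p_dense(counts: np.ndarray, target_sum: float = 1e4) -> np.ndarray:
    """Scale each cell (row) to `target_sum` total counts, then apply log1p."""
    lib = counts.sum(axis=1, keepdims=True).astype(float)
    lib[lib == 0] = 1.0  # guard against empty cells to avoid division by zero
    return np.log1p(counts / lib * target_sum)

counts = np.array([[1, 3], [0, 10]], dtype=float)
norm = normalize_log1p_dense(counts)
```

After undoing the `log1p`, each non-empty cell sums back to `target_sum`, which is the invariant the normalisation step is meant to enforce.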

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment on lines +290 to +291
try:
clf.fit(X_train_scaled, y_train)


suggestion (bug_risk): Catching all exceptions during model fitting may obscure underlying issues.

Catch only specific exceptions like ValueError or ConvergenceWarning to ensure critical errors are not missed during development.

Suggested change

```suggestion
from sklearn.exceptions import ConvergenceWarning
import warnings
try:
    with warnings.catch_warnings():
        warnings.filterwarnings("error", category=ConvergenceWarning)
        clf.fit(X_train_scaled, y_train)
except (ValueError, ConvergenceWarning):
    continue
```
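The pattern in the suggestion above — promoting a warning category to an exception inside a `with warnings.catch_warnings()` block — can be demonstrated with the standard library alone. `ConvergenceWarning` here is a local stand-in for `sklearn.exceptions.ConvergenceWarning`, so no sklearn import is needed in this sketch:

```python
import warnings

class ConvergenceWarning(UserWarning):
    """Local stand-in for sklearn.exceptions.ConvergenceWarning."""

def fit_or_skip(fit) -> bool:
    """Run `fit`; treat a ConvergenceWarning or ValueError as a skip signal."""
    try:
        with warnings.catch_warnings():
            # Escalate the warning to an error only within this block.
            warnings.filterwarnings("error", category=ConvergenceWarning)
            fit()
        return True
    except (ValueError, ConvergenceWarning):
        return False

def noisy_fit():
    warnings.warn("did not converge", ConvergenceWarning)

def clean_fit():
    pass

skipped = fit_or_skip(noisy_fit)  # warning escalated to an error, so the fit is skipped
kept = fit_or_skip(clean_fit)
```

The escalation is scoped to the `with` block, so global warning filters are restored afterwards.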

random_state=random_state,
stratify=labels,
)
except ValueError:


suggestion (bug_risk): Silently continuing on ValueError in train_test_split may skip clusters without notification.

Add logging when clusters are skipped due to insufficient samples for stratification to improve traceability.

Suggested change

```suggestion
except ValueError:
    import logging
    logging.warning(
        "Skipping cluster due to insufficient samples for stratification (cluster size: %d, labels: %s)",
        len(X_cluster), set(labels)
    )
```
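Rather than importing `logging` inside the `except` block as the suggestion shows, a module-level logger is the more idiomatic choice. A minimal sketch — the logger name and the `StringIO` handler wiring are for demonstration only, not part of the PR:

```python
import io
import logging

# Module-level logger, created once at import time.
logger = logging.getLogger("pyXenium.analysis.rna_protein_joint")
stream = io.StringIO()
logger.addHandler(logging.StreamHandler(stream))
logger.setLevel(logging.WARNING)

def try_split(n_cells: int, n_label_classes: int) -> bool:
    """Pretend split that logs and skips when stratification is impossible."""
    if n_label_classes < 2:  # stand-in for the ValueError raised by train_test_split
        logger.warning(
            "Skipping cluster: insufficient samples for stratification (cluster size: %d)",
            n_cells,
        )
        return False
    return True

ok = try_split(40, 1)
logged = stream.getvalue()
```

This keeps the hot loop free of import statements and lets callers control verbosity through the standard logging hierarchy.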

Comment on lines +94 to +95
if "protein" not in adata.obsm:
raise KeyError("AnnData is missing 'protein' modality in adata.obsm['protein'].")


suggestion: KeyError is raised if 'protein' is missing, but no fallback or guidance is provided.

Consider updating the error message to include instructions for adding or computing the 'protein' modality, making it easier for users to address the issue.

Suggested change

```suggestion
if "protein" not in adata.obsm:
    raise KeyError(
        "AnnData is missing 'protein' modality in adata.obsm['protein'].\n"
        "To resolve this, ensure that the protein data is loaded or computed and assigned to adata.obsm['protein'].\n"
        "For example, use pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein or another appropriate method to add the protein modality."
    )
```

cluster_protein = protein_df.iloc[idx]
models.setdefault(cluster, {})

for protein in cluster_protein.columns:


issue (complexity): Consider extracting the protein data preparation and classifier fitting logic into helper functions to flatten the main loop and reduce cyclomatic complexity.

Here are two small helpers you can extract from the inner loops to flatten that nesting, reduce cyclomatic complexity, and keep all functionality:

```python
def _prepare_protein_data(
    values: np.ndarray,
    X_cluster: np.ndarray,
    method: str,
    quantile: float,
    min_cells: int,
) -> Tuple[Optional[float], np.ndarray, np.ndarray]:
    """Return threshold, filtered X and labels or (None, _, _) if not enough cells."""
    finite = np.isfinite(values)
    if finite.sum() < min_cells * 2:
        return None, None, None

    vals = values[finite]
    Xc = X_cluster[finite]

    if method == "median":
        thr = float(np.median(vals))
        labels = (vals >= thr).astype(int)
    elif method == "quantile":
        q = float(quantile)
        if not 0.5 < q < 1.0:
            raise ValueError("protein_quantile must be between 0.5 and 1.0")
        high_thr = np.quantile(vals, q)
        low_mask = vals <= np.quantile(vals, 1 - q)
        high_mask = vals >= high_thr
        sel = low_mask | high_mask
        if sel.sum() < min_cells * 2:
            return None, None, None
        vals = vals[sel]
        Xc = Xc[sel]
        labels = high_mask[sel].astype(int)
        thr = float(high_thr)
    else:
        raise ValueError("protein_split_method must be 'median' or 'quantile'")

    if labels.sum() < min_cells or (len(labels) - labels.sum()) < min_cells:
        return None, None, None

    return thr, Xc, labels


def _fit_protein_classifier(
    X: np.ndarray,
    y: np.ndarray,
    test_size: float,
    random_state: Optional[int],
    hidden_layer_sizes: Tuple[int, ...],
    max_iter: int,
    early_stopping: bool,
) -> Optional[ProteinModelResult]:
    """Train/test split + scaler + MLPClassifier, return result or None on error."""
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=random_state
        )
    except ValueError:
        return None

    scaler = StandardScaler()
    Xtr = scaler.fit_transform(X_train)
    Xte = scaler.transform(X_test)

    clf = MLPClassifier(
        hidden_layer_sizes=hidden_layer_sizes,
        random_state=random_state,
        max_iter=max_iter,
        early_stopping=early_stopping,
    )
    try:
        clf.fit(Xtr, y_train)
    except Exception:
        return None

    train_acc = accuracy_score(y_train, clf.predict(Xtr))
    y_pred = clf.predict(Xte)
    test_acc = accuracy_score(y_test, y_pred)
    if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2:
        try:
            auc = roc_auc_score(y_test, clf.predict_proba(Xte)[:, 1])
        except ValueError:
            auc = float("nan")
    else:
        auc = float("nan")

    return ProteinModelResult(
        protein="",
        cluster="",
        threshold=0.0,
        n_cells=len(y),
        n_high=int(y.sum()),
        n_low=int(len(y) - y.sum()),
        train_accuracy=float(train_acc),
        test_accuracy=float(test_acc),
        test_auc=float(auc),
        model=clf,
        scaler=scaler,
    )
```

Then in your main loop replace the big block with:

```python
for protein in cluster_protein.columns:
    vals = cluster_protein[protein].to_numpy(dtype=np.float32)
    thr, X_sel, labels = _prepare_protein_data(
        vals, cluster_pcs, protein_split_method, protein_quantile, min_cells_per_group
    )
    if thr is None:
        continue

    result = _fit_protein_classifier(
        X_sel, labels, test_size, random_state,
        hidden_layer_sizes, max_iter, early_stopping,
    )
    if result is None:
        continue

    result.protein = protein
    result.cluster = cluster
    result.threshold = thr

    models[cluster][protein] = result
    results.append({
        "cluster": cluster,
        "protein": protein,
        "threshold": thr,
        "n_cells": result.n_cells,
        "n_high": result.n_high,
        "n_low": result.n_low,
        "train_accuracy": result.train_accuracy,
        "test_accuracy": result.test_accuracy,
        "test_auc": result.test_auc,
    })
```

This pulls out ~60 lines of nested logic into two focused helpers and flattens your main loop.
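One benefit of extracting `_prepare_protein_data` is that its splitting logic becomes unit-testable without any model fitting. A sketch of the median branch in isolation, reimplemented here with NumPy only — the `median_split` name and return shape are illustrative, not the PR's API:

```python
import numpy as np

def median_split(values: np.ndarray, min_cells: int):
    """Median split mirroring the 'median' branch of _prepare_protein_data."""
    finite = values[np.isfinite(values)]  # NaNs are dropped, as in the helper
    thr = float(np.median(finite))
    labels = (finite >= thr).astype(int)
    n_high = int(labels.sum())
    if n_high < min_cells or (labels.size - n_high) < min_cells:
        return None, None  # too few cells in one of the groups
    return thr, labels

vals = np.array([0.1, 0.2, 0.3, 0.4, np.nan, 0.5, 0.6], dtype=np.float32)
thr, labels = median_split(vals, min_cells=2)
```

With six finite values the median lands between 0.3 and 0.4, yielding three "high" and three "low" cells — enough to satisfy `min_cells=2`.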

Comment thread src/pyXenium/analysis/rna_protein_joint.py Outdated
threshold = float(np.median(values))
labels = (values >= threshold).astype(int)
elif protein_split_method == "quantile":
q = float(protein_quantile)


suggestion (code-quality): We've found these issues:

Suggested change

```suggestion
q = protein_quantile
```


Explanation

The quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.

How can you solve this?

It might be worth refactoring this function to make it shorter and more readable.

  • Reduce the function length by extracting pieces of functionality out into
    their own functions. This is the most important thing you can do - ideally a
    function should be less than 10 lines.
  • Reduce nesting, perhaps by introducing guard clauses to return early.
  • Ensure that variables are tightly scoped, so that code using related concepts
    sits together within the function rather than being scattered.

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
@hutaobo hutaobo merged commit 14f4190 into main Oct 2, 2025
1 of 5 checks passed